# -*- coding:utf-8 -*-
'''
ref:
http://nbviewer.jupyter.org/github/joyofdata/joyofdata-articles/blob/master/deeplearning-with-caffe/Neural-Networks-with-Caffe-on-the-GPU.ipynb
https://github.com/Franck-Dernoncourt/caffe_demos/blob/master/caffe_vs_sklearn_lr/caffe_vs_sklearn_lr.py
'''
import subprocess
import platform
import sys

sys.path.append("/home/dell/Rui/caffe/python/")
import caffe

caffe.set_mode_gpu()
# import lmdb
import h5py
from sklearn.cross_validation import StratifiedShuffleSplit
import pandas as pd
import math
import numpy as np
import argparse
import os, glob
from classify import get_net, get_transformer, load_image, load_mean_sub_image, read_labels, forward_pass, \
    feature_extract_forward_pass, create_submission
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV, RandomizedLogisticRegression
import time


# status quo
# print "OS:     ", platform.platform()
# print "Python: ", sys.version.split("\n")[0]
# print "CUDA:   ", subprocess.Popen(["nvcc", "--version"], stdout=subprocess.PIPE).communicate()[0].split("\n")[3]
# print "LMDB:   ", ".".join([str(i) for i in lmdb.version()])


def kfold_fine_tuning(num_folds=5, num_crops=1, num_epoch=5, solver_type='fix'):
    # train model 5 epoch 5 fold cv
    print 'kfold cv fine tuning...'
    logdir = '/home/lab/xryu/code/DistractedDriverDetection/kaggle_ddd_finetune_caffe/fine_tune_googlenet/{0}foldcv/log/{1}crop/'.format(
        num_folds, num_crops)
    for i in range(num_folds):
        print 'fine tuning...fold:{0}/{1}'.format(i + 1, num_folds)

        proc = subprocess.Popen(
            ["/home/lab/wjSun/caffe-master/build/tools/caffe",
             "train",
             "--solver=/home/lab/xryu/code/DistractedDriverDetection/kaggle_ddd_finetune_caffe/fine_tune_googlenet/{0}foldcv/"
             "{1}crop/{2}epoch/solver_{3}_cv_{4}.prototxt".format(num_folds, num_crops, num_epoch, solver_type, i),
             "--weights=/media/lab/labdisk/home/lab1314/xryu/pretrined_model/caffe/googlenet/imagenet_googlenet.caffemodel",
             # " 2>&1 | tee {0}5epoch_cv_{1}_{2}.txt".format(logdir,solver_type,i)
             ],
            stderr=subprocess.PIPE,
            # stdout=subprocess.PIPE
            # stdout=open(logdir + 'subprocess_caffe_{0}epoch_cv_{1}_{2}.txt'.format(num_epoch, solver_type,i), 'w')
        )
        res = proc.communicate()[1]
        '''
        '''
        subprocess.Popen(['sh',
                          '/home/lab/xryu/code/DistractedDriverDetection/kaggle_ddd_finetune_caffe/fine_tune_googlenet/5foldcv/1crop/run_fix_cv.sh'],
                         stderr=subprocess.PIPE)
        # print res
        f = open(logdir + 'subprocess_caffe_{0}epoch_cv_{1}_{2}.txt'.format(num_epoch, solver_type, i), 'w')
        f.write(res)
        f.close()
        print 'fold:{0}/{1} done, log saved'.format(i + 1, num_folds)


def sigmoid(X):
    """Element-wise logistic sigmoid, 1 / (1 + exp(-X)).

    Uses np.exp instead of ``math.e ** (-X)`` so that large-magnitude
    scalar inputs saturate to 0.0 / 1.0 instead of raising OverflowError
    (``math.e ** x`` overflows for x > ~709), and so arrays vectorize
    through the same code path.
    """
    den = 1.0 + np.exp(-1.0 * X)
    d = 1.0 / den
    return d


def softmax(W, b, x):
    """Row-wise softmax of the affine map x . W^T + b.

    W : (num_classes, dim) weight matrix
    b : (num_classes,) bias vector
    x : (n_samples, dim) input features
    Returns an (n_samples, num_classes) array whose rows sum to 1.
    """
    logits = np.dot(x, W.T) + b
    exp_logits = np.exp(logits)
    row_totals = np.sum(exp_logits, axis=1)
    # divide each row by its total (transpose trick broadcasts per-row)
    normalized = exp_logits.T / row_totals
    return normalized.T


def softmaxPredict(num_classes, input_size, theta_seed, input):
    """
    Predict the class of one feature vector with a randomly initialised
    softmax model.

    Parameters
    ----------
    num_classes : number of classes (e.g. 10)
    input_size : feature dimensionality (e.g. 10)
    theta_seed : scale applied to the random weight draw (e.g. 0.005)
    input : feature vector of the test sample

    Returns
    -------
    (prediction, probability) -- a (1, num_classes) array whose first
    column holds the argmax class index, plus the probability vector.

    NOTE(review): `theta` / `theta_x` are computed but never used below --
    `hypothesis` is exp(input) rather than exp(theta_x) -- so the result
    ignores the random weights entirely, and the RNG is seeded from
    wall-clock time (non-reproducible). This function is only referenced
    from commented-out code; confirm intent before relying on it.
    """
    """ Reshape 'theta' for ease of computation """
    # draw num_classes * input_size weights, scaled by theta_seed
    rand = np.random.RandomState(int(time.time()))
    theta = theta_seed * np.asarray(rand.normal(size=(num_classes * input_size, 1)))

    theta = theta.reshape(input_size, num_classes)

    """ Compute the class probabilities for each example """

    theta_x = np.dot(theta, input)
    # NOTE(review): presumably this should be np.exp(theta_x) -- TODO confirm
    hypothesis = np.exp(input)
    probability = hypothesis / np.sum(hypothesis, axis=0)

    """ Give the predictions based on probability values """

    prediction = np.zeros((1, num_classes))
    prediction[:, 0] = np.argmax(probability, axis=0)

    return prediction, probability


def classify_extracted_feature_cls3_fc_ddd(caffemodel, deploy_file,
                                           test_image_files, feature='cls3_fc_ddd',EXTRACT=False,
                                           mean_file=None, labels_file=None, batch_size=None, use_gpu=True, DEBUG=True):
    """
    Extract deep features with a Caffe model, then classify the test set by
    fitting a multinomial logistic regression on cached fold-0 train features.

    Arguments:
    caffemodel -- path to a .caffemodel
    deploy_file -- path to a .prototxt
    test_image_files -- directory containing the test .jpg images

    Keyword arguments:
    feature -- name of the network blob to extract (default 'cls3_fc_ddd')
    EXTRACT -- if True, (re-)extract features for the 5 CV train/val splits
    mean_file -- path to a .binaryproto
    labels_file -- path to a .txt file
    batch_size -- forward-pass batch size
    use_gpu -- if True, run inference on the GPU
    DEBUG -- if True, only the first 10 test images are processed

    Returns:
    (probabilities, test_id, test_feats)
    """
    # Load the model and images
    net = get_net(caffemodel, deploy_file, use_gpu)
    transformer = get_transformer(deploy_file, mean_file)
    if mean_file is not None:
        # input geometry is read from the deploy prototxt
        _, channels, height, width = transformer.inputs['data']
        mean_file = mean_file
    else:
        # no mean file given: assume the 224x224 3-channel GoogLeNet input
        channels = 3
        height = 224
        width = 224
        # NOTE(review): reassigning to False means the `mean_file is None`
        # branch further down can never run -- confirm whether manual mean
        # subtraction was intended.
        mean_file = False
    if channels == 3:
        mode = 'RGB'
    elif channels == 1:
        mode = 'L'
    else:
        raise ValueError('Invalid number for channels: %s' % channels)
    # img_path = os.path.join(data_root, 'imgs', 'test', '*.jpg')
    print 'extract cross validation train val images features...'
    cache = '/media/wjsun/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/cache'
    TRAIN_IMG_ROOT = '/media/wjsun/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/imgs/train/'
    cv_train_val_txt_path = '/media/wjsun/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/fine-tune-caffe/data_txt/driver_sorted/5foldcv/'
    if EXTRACT:
        # extract and cache features for each of the 5 CV folds
        for i in range(5):
            print 'fold{0}'.format(i)
            # todo
            train_txt = cv_train_val_txt_path + 'train_list_{0}.txt'.format(i)
            val_txt = cv_train_val_txt_path + 'val_list_{0}.txt'.format(i)
            train_files = open(train_txt).readlines()
            val_files = open(val_txt).readlines()

            print len(train_files)
            print len(val_files)
            #print train_files[0]
            #print val_files[0]
            # print [TRAIN_IMG_ROOT + train_image_file.split(' ')[0] for train_image_file in train_files]
            # each line is "<relative path> <label>\n"; [-1][:-1] strips the newline
            train_labels = [train_image_file.split(' ')[-1][:-1] for train_image_file in train_files]
            val_labels = [val_image_file.split(' ')[-1][:-1] for val_image_file in val_files]
            print 'train labels', train_labels[0:10]
            train_images = [load_image(TRAIN_IMG_ROOT + train_image_file.split(' ')[0], height, width, mode) for
                            train_image_file in train_files]
            val_images = [load_image(TRAIN_IMG_ROOT + val_image_file.split(' ')[0], height, width, mode) for val_image_file
                          in val_files]

            # cache raw images and labels so later runs can skip the image I/O
            np.save(cache + '/train_images_224_cv_{0}.npy'.format(i), train_images)
            np.save(cache + '/val_images_224_cv_{0}.npy'.format(i), val_images)
            np.save(cache + '/train_labels_224_cv_{0}.npy'.format(i), train_labels)
            np.save(cache + '/val_labels_224_cv_{0}.npy'.format(i), val_labels)
            train_images_arr = np.asarray(train_images)
            val_images_arr = np.asarray(val_images)
            train_feats = feature_extract_forward_pass(train_images_arr, net, transformer, feature=feature,
                                                       batch_size=batch_size)
            val_feats = feature_extract_forward_pass(val_images_arr, net, transformer, feature=feature,
                                                     batch_size=batch_size)
            np.save(cache + '/googlenet_cls3_fc_ddd_train_feats_cv_{0}.npy'.format(i), train_feats)
            np.save(cache + '/googlenet_cls3_fc_ddd_val_feats_cv_{0}.npy'.format(i), val_feats)
            print 'train feats cv {0} shape'.format(i), train_feats.shape
            print 'val feats cv {0} shape'.format(i), val_feats.shape

    print 'extract test image features...'
    test_img_files = glob.glob(test_image_files + '/*.jpg')
    print 'loading test images...'
    # test images are expected to be cached already by a previous run
    test_images = np.load(cache + '/test_images_224.npy')
    # images = [load_image(image_file, height, width, mode) for image_file in img_files]
    # np.save(cache + '/test_images_224.npy', images)
    test_images_arr = np.asarray(test_images)
    if DEBUG:
        test_images_arr = test_images_arr[:10]
    else:
        test_images_arr = test_images_arr
    print 'test_images', test_images_arr.shape
    labels = read_labels(labels_file)
    # NOTE(review): dead branch -- mean_file was reassigned to False above
    # whenever it started as None, so this manual BGR mean subtraction never
    # executes (and `images` is never used afterwards anyway).
    if mean_file is None:
        mean_vec = np.array([103.939, 116.779, 123.68], dtype=np.float32)
        mean_vec_rgb = np.array([123.68, 116.779, 103.939], dtype=np.float32)
        reshaped_mean_vec = mean_vec.reshape(3, 1, 1)
        images = test_images_arr - reshaped_mean_vec
        print images.shape, mean_vec.shape, reshaped_mean_vec.shape
    # submission ids are the test image basenames
    test_id = []
    for fl in test_img_files:
        flbase = os.path.basename(fl)
        test_id.append(flbase)
        # _, test_id = load_test(224, 224, 3)

    # Classify the image
    print 'extracting test feature...'
    test_feats = feature_extract_forward_pass(test_images_arr, net, transformer, feature=feature, batch_size=batch_size)
    np.save(cache + '/googlenet_cls3_fc_ddd_feats.npy', test_feats)
    # NOTE(review): the freshly extracted cls3_fc_ddd features are immediately
    # replaced by cached features from a different layer (icp9) -- confirm
    # which feature set is actually meant to be classified.
    test_feats = np.load(cache + '/googlenet_icp9_out_feats.npy')
    print 'test feats shape', test_feats.shape

    print 'classifying using extracted feature'
    # todo
    '''
    multiclass logistic regression cv
    '''
    # fit on fold-0 cached train features/labels only (not all folds)
    train_features=np.load(cache + '/googlenet_cls3_fc_ddd_train_feats_cv_0.npy')
    truth_train=np.load(cache + '/train_labels_224_cv_0.npy')
    clf = LogisticRegression(penalty='l2', dual=False, tol=0.0001, C=1e5, fit_intercept=True, intercept_scaling=1,
                             class_weight='balanced', random_state=None, solver='lbfgs', max_iter=10000, multi_class='multinomial',
                             verbose=1, warm_start=False, n_jobs=10,
                             #class_weight='balanced', multi_class='multinomial', solver='lbfgs'
                             )
    clf.fit(train_features, truth_train)
    predictions=clf.predict(test_feats)
    probabilities = clf.predict_proba(test_feats)
    print np.shape(probabilities)
    """
    predictions = []
    probabilities = []
    cnt = 0
    for i in range(test_feats.shape[0] / 10000):
        # print '%.5f'%feats[i]
        prediction, probability = softmaxPredict(num_classes=10, input_size=test_feats.shape[0],
                                                 input=test_feats[i])
        predictions.append(prediction)
        probabilities.append(probability)
        cnt += 1
        if cnt % 1000 == 0:
            print cnt, 'predicted'
    """
    np.save(cache + '/googlenet_cls3_fc_ddd_LogisticRegression_predictions.npy', predictions)
    np.save(cache + '/googlenet_cls3_fc_ddd_LogisticRegression_probabilities.npy', probabilities)
    predictions = np.asarray(predictions)
    probabilities = np.asarray(probabilities)
    print 'predictions', predictions.shape, predictions[:2]
    print 'probalities', probabilities.shape, probabilities[:2]

    return probabilities, test_id, test_feats


def classify(caffemodel, deploy_file, image_files,
             mean_file=None, labels_file=None, batch_size=None, use_gpu=True, DEBUG=True):
    """
    Classify some images against a Caffe model and print the results

    Arguments:
    caffemodel -- path to a .caffemodel
    deploy_file -- path to a .prototxt
    image_files -- directory containing the .jpg images to classify

    Keyword arguments:
    mean_file -- path to a .binaryproto
    labels_file -- path to a .txt file
    batch_size -- forward-pass batch size
    use_gpu -- if True, run inference on the GPU
    DEBUG -- if True, always load images from disk; otherwise reuse the
        cached 224x224 test-image array when available

    Returns:
    (scores, test_id) -- per-image class scores and image basenames
    """
    # Load the model and images
    net = get_net(caffemodel, deploy_file, use_gpu)
    transformer = get_transformer(deploy_file, mean_file)
    if mean_file is not None:
        # input geometry is read from the deploy prototxt
        _, channels, height, width = transformer.inputs['data']
        mean_file = mean_file
    else:
        # no mean file given: assume the 224x224 3-channel GoogLeNet input
        channels = 3
        height = 224
        width = 224
        # NOTE(review): reassigning to False makes the `mean_file is None`
        # branch below unreachable -- confirm intent.
        mean_file = False
    if channels == 3:
        mode = 'RGB'
    elif channels == 1:
        mode = 'L'
    else:
        raise ValueError('Invalid number for channels: %s' % channels)
    # img_path = os.path.join(data_root, 'imgs', 'test', '*.jpg')
    img_files = glob.glob(image_files + '/*.jpg')
    print 'loading images...'
    cache = '/media/wjsun/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/cache'
    if DEBUG:
        if 'mean_sub_sample' in img_files:
            images = [load_mean_sub_image(image_file, height, width, mode) for image_file in img_files]
        else:
            images = [load_image(image_file, height, width, mode) for image_file in img_files]
        images_arr = np.asarray(images)
        # print 'loading image 10 crop'
        # images = [load_image_10crop(image_file, height, width, mode) for image_file in img_files]
    # np.save(cache+'/test_images.npy',images)
    else:
        if 'mean_sub_sample' in img_files:
            images = [load_mean_sub_image(image_file, height, width, mode) for image_file in img_files]
        else:
            # reuse the cached test-image array instead of re-reading from disk
            images = np.load(cache + '/test_images_224.npy')
            # images = [load_image(image_file, height, width, mode) for image_file in img_files]
            # np.save(cache + '/test_images_224.npy', images)
        images_arr = np.asarray(images)
    labels = read_labels(labels_file)
    # NOTE(review): dead branch -- mean_file was reassigned to False above
    # whenever it started as None, so this manual mean subtraction never
    # runs and `images` stays whatever was loaded above.
    if mean_file is None:
        mean_vec = np.array([103.939, 116.779, 123.68], dtype=np.float32)
        mean_vec_rgb = np.array([123.68, 116.779, 103.939], dtype=np.float32)
        reshaped_mean_vec = mean_vec.reshape(3, 1, 1)
        images = images_arr - reshaped_mean_vec
        print images.shape, mean_vec.shape, reshaped_mean_vec.shape
    # submission ids are the image basenames
    test_id = []
    for fl in img_files:
        flbase = os.path.basename(fl)
        test_id.append(flbase)
        # _, test_id = load_test(224, 224, 3)

    # Classify the image
    print 'classify...'
    scores = forward_pass(images, net, transformer, batch_size=batch_size)
    # print 'scores shape', scores.shape
    # print 'scores',scores[0:5]
    location = np.argmax(scores, axis=1)
    # print 'calss', location
    # report the percentage of images predicted with confidence > 0.9
    # (Python 2 integer division, so the printed percentage is truncated)
    cnt = 0
    for i in range(len(scores)):
        # print 'prob', scores[i, location[i]]
        if scores[i, location[i]] > 0.9:
            cnt += 1
    print cnt * 100 / len(scores), '%'
    # print 'scores', scores
    # for i in range(len(img_files)):
    #   print 'predicted clss:', scores[i].argmax()

    # test_res = merge_several_folds_mean(scores, nfolds=1)
    # print np.asarray(test_res).shape
    return scores, test_id


def merge_several_folds_mean(data, nfolds):
    """Arithmetic mean of the first `nfolds` per-fold score arrays.

    Returns the element-wise average as a plain (nested) list.
    """
    total = np.array(data[0])
    for fold in data[1:nfolds]:
        total += np.array(fold)
    total /= nfolds
    return total.tolist()


def merge_several_folds_geom(data, nfolds):
    """Geometric mean of the first `nfolds` per-fold score arrays.

    Returns the element-wise geometric mean as a plain (nested) list.

    Bug fix: the exponent was written as ``1 / nfolds``, which under
    Python 2 integer division is 0 for nfolds > 1, so np.power returned
    all-ones instead of the nfolds-th root. Use 1.0 / nfolds.
    """
    a = np.array(data[0])
    for i in range(1, nfolds):
        a *= np.array(data[i])
    # nfolds-th root of the element-wise product
    a = np.power(a, 1.0 / nfolds)
    return a.tolist()


def merge_several_folds_harm(data, nfolds):
    """Harmonic mean of the first `nfolds` per-fold score arrays.

    Returns the element-wise harmonic mean as a plain (nested) list.
    (The original author noted this merge makes results worse.)
    """
    # accumulate the sum of reciprocals across folds
    recip_sum = 1 / np.array(data[0])
    for i in range(1, nfolds):
        recip_sum += np.array(1 / data[i])
    harmonic = nfolds / recip_sum
    return harmonic.tolist()


def kfold_classification(num_folds=5, num_crops=1, num_epoch=5, merge_folds=5, load_iter=32800, DEBUG=True,
                         solver_type='fix',
                         batch_size=16,
                         use_gpu=True):
    """
    Classify the test set with each fold's fine-tuned snapshot, save
    per-fold scores and submissions, then bag the folds (arithmetic mean
    and harmonic mean) into merged submissions.

    Parameters
    ----------
    num_folds : total number of CV folds (used in snapshot paths)
    num_crops : crop count (used in snapshot/output paths)
    num_epoch : unused here; kept for call-site symmetry with fine-tuning
    merge_folds : how many folds to classify and merge
    load_iter : snapshot iteration number to load
    DEBUG : if True, classify the small sample image set instead of the
        full test set
    solver_type : learning-rate policy tag in the snapshot path
    batch_size : forward-pass batch size
    use_gpu : if True, run inference on the GPU
    """
    if DEBUG:
        image_files = sample_img
    else:
        image_files = TEST_IMG_ROOT
    kfold_scores = []
    info_strings = []
    # for i in [0,1,3,4]:
    for i in range(merge_folds):
        print 'classfy {0} fold'.format(i + 1)
        model_name = '{0}foldcv/{1}crop/googlenet/{2}_{3}/_iter_{4}.caffemodel'.format(num_folds, num_crops,
                                                                                       solver_type, i,
                                                                                       load_iter)
        # model_name = '{0}foldcv/googlenet_step_{1}/_iter_{2}.caffemodel'.format(num_folds, i, load_iter)
        model_weights = snapshot_root + model_name
        # build a human-readable tag from the snapshot path components
        info_string = model_weights.split('/')[-4] + model_weights.split('/')[-3] + '_' + model_weights.split('/')[-2] + \
                      model_weights.split('/')[-1].split('.')[0]
        scores, test_id = classify(caffemodel=model_weights,
                                   deploy_file=model_def,
                                   image_files=image_files,
                                   mean_file=imagenet_mean_path,
                                   labels_file=None,
                                   batch_size=batch_size,
                                   use_gpu=use_gpu,
                                   DEBUG=DEBUG
                                   )
        create_submission(scores, test_id, info_string + '_fold_{0}_test_batch_size_{1}'.format(i + 1, batch_size))
        kfold_scores.append(scores)
        info_strings.append(info_string)
        np.save(
            snapshot_root + '5foldcv/{0}crop/'.format(num_crops) + '5fold_scores_{0}_{1}_iter_{2}.npy'.format(
                solver_type, i,
                load_iter),
            scores)
    np.save(
        snapshot_root + '5foldcv/{0}crop/'.format(num_crops) + '5fold_all_scores_{0}_{1}.npy'.format(solver_type,
                                                                                                     load_iter),
        kfold_scores)
    # bagging
    print 'bagging'
    merged_scores_mean = merge_several_folds_mean(kfold_scores, merge_folds)
    np.save(snapshot_root + '5foldcv/{0}crop/'.format(num_crops) + '5fold_scores_merged_mean_{0}.npy'.format(
        load_iter),
            merged_scores_mean)

    # NOTE(review): harmonic-mean scores are saved under a "geom" filename --
    # confirm which merge this file is meant to hold.
    merged_scores_harm = merge_several_folds_harm(kfold_scores, merge_folds)
    np.save(snapshot_root + '5foldcv/{0}crop/'.format(num_crops) + '5fold_scores_merged_geom_{0}.npy'.format(
        load_iter),
            merged_scores_harm)
    print 'creating submission'
    # test_id comes from the last classify() call; same ids for every fold.
    # NOTE(review): 'test_bach_size' typo below is part of the output
    # filename -- fixing it would change where results land.
    create_submission(merged_scores_mean, test_id,
                      info_strings[0] + '_test_bach_size_{0}_{1}crop_merge_score_mean'.format(batch_size, num_crops))
    create_submission(merged_scores_harm, test_id,
                      info_strings[0] + '_test_batch_size_{0}_{1}crop_merge_score_harm'.format(batch_size,
                                                                                               num_crops))  # it's bad
    print 'done'


# ---- fixed filesystem layout for models, images, and the ImageNet mean ----
# root directory of all fine-tuning snapshots
snapshot_root = '/media/wjsun/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/fine-tune-caffe/snapshot/'
# full Kaggle test image set
TEST_IMG_ROOT = '/media/wjsun/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/imgs/test'
# small sample image set, used when DEBUG=True
sample_img = '/media/wjsun/dell/Rui/data/Kaggle/DistractedDriverDetection/imgs/c4_p072_sample'
# pre-mean-subtracted test images (BGR); not referenced below
mean_sub_sampe = '/media/wjsun/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/imgs/test_mean_substracted_224/bgr'
# googlenet
googlenet_root = '/home/dell/Rui/kaggle_ddd_finetune_caffe/fine_tune_googlenet/'
model_def = googlenet_root + '5foldcv/1crop/googlenet_test.prototxt'
imagenet_mean_path = '/media/wjsun/delldisk/dell/Rui/pretrined_model/caffe/googlenet/imagenet_mean.binaryproto'

if __name__ == '__main__':
    """
    Usage: python kfold_cv_caffe.py -solver_type fix
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-solver_type', '--solver_type', default='fix',
                        help='solver type: fix, step, sigmoid')
    args = parser.parse_args()
    solver_type = args.solver_type

    print 'k fold cv'
    num_crop = 1
    iters = []
    # expected snapshot directories for the 5 'fix' folds
    for i in range(0, 5):
        exit_iter_path = snapshot_root + '5foldcv/1crop/googlenet/fix_{0}'.format(i)
        iters.append(exit_iter_path)
        # print iters[i]
    # only classify once fold-0 training has finished; the checked
    # _iter_3000 snapshot matches load_iter = 600 * 5 * num_crop = 3000
    if os.path.exists(iters[0] + '/_iter_3000.caffemodel'):
        print 'yes'
        kfold_classification(num_folds=5, num_crops=num_crop, num_epoch=5, merge_folds=5, solver_type=solver_type,
                             load_iter=600 * 5 * num_crop,
                             DEBUG=False, batch_size=32)
        # second pass with a larger test batch size
        kfold_classification(num_folds=5, num_crops=num_crop, num_epoch=5, merge_folds=5, solver_type=solver_type,
                             load_iter=600 * 5 * num_crop,
                             DEBUG=False, batch_size=64)
    else:
        print 'no'



        # kfold_fine_tuning(num_folds=5, num_crops=1, num_epoch=5,solver_type='fix')


        # kfold_classification(num_folds=5, num_crops=1, num_epoch=5, merge_folds=5, solver_type='fix', load_iter=1400 * 5,
        #                     DEBUG=False, batch_size=16)
        # kfold_classification(num_folds=5, num_crops=1, num_epoch=5, merge_folds=5, solver_type='fix', load_iter=1400 * 3,
        #                     DEBUG=False, batch_size=16)
        # kfold_fine_tuning(num_folds=5, num_crops=5, num_epoch=5)
        # kfold_classification(num_folds=5, num_crops=5, num_epoch=5, merge_folds=5, load_iter=7000 * 5, DEBUG=False,
        #                     batch_size=32)

        # kfold_classification(num_folds=5, num_crops=1, merge_folds=5, load_iter=5600, DEBUG=False, batch_size=32)
        # kfold_classification(num_folds=5, num_crops=5, merge_folds=4, load_iter=42000, DEBUG=False, batch_size=32)
        # kfold_classification(num_folds=5, num_crops=1, merge_folds=4, load_iter=7000, DEBUG=False, batch_size=32)
        # http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/hdf5_classification.ipynb
        # or
        # caffe.set_mode_gpu()
        # solver = caffe.get_solver("config.prototxt")
        # solver.solve()
